{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretable Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we will fit classification explainable boosting machine (EBM), LogisticRegression, and ClassificationTree models. After fitting them, we will use their glassbox nature to understand their global and local explanations.\n",
    "\n",
    "This notebook can be found in our [**_examples folder_**](https://github.com/interpretml/interpret/tree/develop/docs/interpret/python/examples) on GitHub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install interpret if not already installed\n",
    "try:\n",
    "    import interpret\n",
    "except ModuleNotFoundError:\n",
    "    !pip install --quiet interpret pandas scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from interpret import show\n",
    "from interpret.perf import ROC\n",
    "\n",
    "from interpret import set_visualize_provider\n",
    "from interpret.provider import InlineProvider\n",
    "set_visualize_provider(InlineProvider())\n",
    "\n",
    "df = pd.read_csv(\n",
    "    \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n",
    "    header=None)\n",
    "df.columns = [\n",
    "    \"Age\", \"WorkClass\", \"fnlwgt\", \"Education\", \"EducationNum\",\n",
    "    \"MaritalStatus\", \"Occupation\", \"Relationship\", \"Race\", \"Gender\",\n",
    "    \"CapitalGain\", \"CapitalLoss\", \"HoursPerWeek\", \"NativeCountry\", \"Income\"\n",
    "]\n",
    "X = df.iloc[:, :-1]\n",
    "y = (df.iloc[:, -1] == \" >50K\").astype(int)\n",
    "\n",
    "seed = 42\n",
    "np.random.seed(seed)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Explore the dataset</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.data import ClassHistogram\n",
    "\n",
    "hist = ClassHistogram().explain_data(X_train, y_train, name='Train Data')\n",
    "show(hist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Train the Explainable Boosting Machine (EBM)</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.glassbox import ExplainableBoostingClassifier\n",
    "\n",
    "ebm = ExplainableBoostingClassifier()\n",
    "ebm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>EBMs are glassbox models, so we can edit them</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# post-process monotonize the Age feature\n",
    "ebm.monotonize(\"Age\", increasing=True)"
   ]
  },
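  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `monotonize` call above forces the Age term to be non-decreasing. One standard way to make a sequence of scores monotone is isotonic regression, and the sketch below applies scikit-learn's `IsotonicRegression` to made-up per-bin scores to show the effect. The scores, bin weights, and variable names are hypothetical and are not read from the fitted EBM; this is a toy illustration under those assumptions, not the library's internal implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "\n",
    "# Hypothetical per-bin scores for a single feature (NOT taken from the fitted EBM)\n",
    "bin_index = np.arange(6)\n",
    "scores = np.array([-0.8, -0.2, -0.4, 0.1, 0.5, 0.3])\n",
    "# Hypothetical number of training samples that fall in each bin\n",
    "weights = np.array([50, 40, 30, 30, 20, 10])\n",
    "\n",
    "# Fit a weighted, non-decreasing step function to the scores\n",
    "iso = IsotonicRegression(increasing=True)\n",
    "iso.fit(bin_index, scores, sample_weight=weights)\n",
    "monotone_scores = iso.predict(bin_index)\n",
    "\n",
    "print('original :', scores)\n",
    "print('monotone :', monotone_scores)"
   ]
  },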
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Global Explanations: What the model learned overall</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ebm_global = ebm.explain_global(name='EBM')\n",
    "show(ebm_global)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Local Explanations: How an individual prediction was made</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ebm_local = ebm.explain_local(X_test[:5], y_test[:5], name='EBM')\n",
    "show(ebm_local, 0)"
   ]
  },
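  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An EBM is an additive model, so a local explanation can be read as a sum: each term contributes a score in log-odds units, the scores are added to an intercept, and the total is passed through the logistic function to get a probability. The sketch below works through that arithmetic with made-up numbers. The intercept, the contribution values, and the variable names are hypothetical and are not taken from the model fitted above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical intercept and per-term contributions, all in log-odds units\n",
    "intercept = -1.5\n",
    "contributions = {'Age': 0.4, 'MaritalStatus': 0.9, 'CapitalGain': -0.2}\n",
    "\n",
    "# Additive model: the logit is the intercept plus the sum of term contributions\n",
    "logit = intercept + sum(contributions.values())\n",
    "\n",
    "# The logistic (sigmoid) function maps the logit to a probability\n",
    "probability = 1.0 / (1.0 + np.exp(-logit))\n",
    "\n",
    "print(f'logit = {logit:.2f}, predicted probability = {probability:.3f}')"
   ]
  },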
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Evaluate EBM performance</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ebm_perf = ROC(ebm).explain_perf(X_test, y_test, name='EBM')\n",
    "show(ebm_perf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Let's test out a few other Explainable Models</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.glassbox import LogisticRegression, ClassificationTree\n",
    "\n",
    "# We have to transform categorical variables to use Logistic Regression and Decision Tree\n",
    "X = pd.get_dummies(X, prefix_sep='.').astype(float)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)\n",
    "\n",
    "lr = LogisticRegression(random_state=seed, penalty='l1', solver='liblinear')\n",
    "lr.fit(X_train, y_train)\n",
    "\n",
    "tree = ClassificationTree()\n",
    "tree.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Compare performance using the Dashboard</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_perf = ROC(lr).explain_perf(X_test, y_test, name='Logistic Regression')\n",
    "show(lr_perf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_perf = ROC(tree).explain_perf(X_test, y_test, name='Classification Tree')\n",
    "show(tree_perf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Glassbox: All of our models have global and local explanations</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_global = lr.explain_global(name='Logistic Regression')\n",
    "show(lr_global)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_global = tree.explain_global(name='Classification Tree')\n",
    "show(tree_global)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
